{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Extending pgmpy"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "It's really easy to extend pgmpy to quickly prototype your ideas. pgmpy has a base abstract class for most of main functionalities like: `BaseInference` for inference, `BaseFactor` for model parameters, `BaseEstimators` for parameter and model learning. For adding a new feature to pgmpy we just need to implement a new class inheriting one of these base classes and then we can use the new class with other functionality of pgmpy.\n",
    "\n",
    "In this example we will see how to write a new inference algorithm. We will take the example of a very simple algorithm in which we will multiply all the factors/CPD of the network and marginalize over variable to get the desired query."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A simple Exact inference algorithm\n",
    "\n",
    "import itertools\n",
    "\n",
    "from pgmpy.inference.base import Inference\n",
    "from pgmpy.factors import factor_product\n",
    "\n",
    "\n",
    "class SimpleInference(Inference):\n",
    "    def __init__(self,model):\n",
    "        super(SimpleInference, self).__init__(model)\n",
    "        self._initialize_structures()\n",
    "        \n",
    "    # By inheriting Inference we can use self.model, self.factors and self.cardinality in our class\n",
    "    def query(self, var, evidence):\n",
    "        # self.factors is a dict of the form of {node: [factors_involving_node]}\n",
    "        factors_list = set(itertools.chain(*self.factors.values()))\n",
    "        product = factor_product(*factors_list)\n",
    "        reduced_prod = product.reduce(evidence, inplace=False)\n",
    "        reduced_prod.normalize()\n",
    "        var_to_marg = (\n",
    "            set(self.model.nodes()) - set(var) - set([state[0] for state in evidence])\n",
    "        )\n",
    "        marg_prod = reduced_prod.marginalize(var_to_marg, inplace=False)\n",
    "        return marg_prod"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Defining a model\n",
    "\n",
    "from pgmpy.models import DiscreteBayesianNetwork\n",
    "from pgmpy.factors.discrete import TabularCPD\n",
    "\n",
    "model = DiscreteBayesianNetwork([(\"A\", \"J\"), (\"R\", \"J\"), (\"J\", \"Q\"), (\"J\", \"L\"), (\"G\", \"L\")])\n",
    "cpd_a = TabularCPD(\"A\", 2, values=[[0.2], [0.8]])\n",
    "cpd_r = TabularCPD(\"R\", 2, values=[[0.4], [0.6]])\n",
    "cpd_j = TabularCPD(\n",
    "    \"J\",\n",
    "    2,\n",
    "    values=[[0.9, 0.6, 0.7, 0.1], [0.1, 0.4, 0.3, 0.9]],\n",
    "    evidence=[\"A\", \"R\"],\n",
    "    evidence_card=[2, 2],\n",
    ")\n",
    "cpd_q = TabularCPD(\n",
    "    \"Q\", 2, values=[[0.9, 0.2], [0.1, 0.8]], evidence=[\"J\"], evidence_card=[2]\n",
    ")\n",
    "cpd_l = TabularCPD(\n",
    "    \"L\",\n",
    "    2,\n",
    "    values=[[0.9, 0.45, 0.8, 0.1], [0.1, 0.55, 0.2, 0.9]],\n",
    "    evidence=[\"J\", \"G\"],\n",
    "    evidence_card=[2, 2],\n",
    ")\n",
    "cpd_g = TabularCPD(\"G\", 2, values=[[0.6], [0.4]])\n",
    "\n",
    "model.add_cpds(cpd_a, cpd_r, cpd_j, cpd_q, cpd_l, cpd_g)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Doing inference with our SimpleInference\n",
    "\n",
    "infer = SimpleInference(model)\n",
    "a = infer.query(var=[\"A\"], evidence=[(\"J\", 0), (\"R\", 1)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+-----+----------+\n",
      "| A   |   phi(A) |\n",
      "|-----+----------|\n",
      "| A_0 |   0.6000 |\n",
      "| A_1 |   0.4000 |\n",
      "+-----+----------+\n"
     ]
    }
   ],
   "source": [
    "print(a)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+-----+----------+\n",
      "| A   |   phi(A) |\n",
      "|-----+----------|\n",
      "| A_0 |   0.6000 |\n",
      "| A_1 |   0.4000 |\n",
      "+-----+----------+\n"
     ]
    }
   ],
   "source": [
    "# Comparing the results with Variable Elimination Algorithm\n",
    "\n",
    "from pgmpy.inference import VariableElimination\n",
    "\n",
    "infer = VariableElimination(model)\n",
    "result = infer.query([\"A\"], evidence={\"J\": 0, \"R\": 1})\n",
    "print(result)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Similarly we can also create new classes for Factor or CPDs and add them to networks and do inference over it or can write a new estimator class."
   ]
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
